{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e49eb9d9-7680-4d81-a1a6-5c440f97b277",
   "metadata": {},
   "source": [
    "# Creating the AGNEWs `LabelledSimpleDataset`"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ccca313-62b2-4eaa-a494-16fd19a89603",
   "metadata": {},
   "source": [
    "In this notebook, we take the AGNEWs dataset ([original source](https://www.kaggle.com/datasets/amananandrai/ag-news-classification-dataset)) and turn it into a `LabelledSimpleDataset` that we ultimately run the `DiffPrivateSimpleDatasetPack` on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a344e90-a265-476c-8eec-6c77b45b2f85",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.llama_dataset.simple import (\n",
    "    LabelledSimpleDataExample,\n",
    "    LabelledSimpleDataset,\n",
    ")\n",
    "from llama_index.core.llama_dataset.base import CreatedBy, CreatedByType\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0903c9e8-bbde-4488-896e-9a83ee27af01",
   "metadata": {},
   "source": [
    "### Load data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d611b3c1-57dd-4d74-a1a6-529b6950abab",
   "metadata": {},
   "source": [
    "The dataset is also available from our public Dropbox."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33024450-806e-473d-b395-c600e7db328d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2024-03-18 12:21:01--  https://www.dropbox.com/scl/fi/wzcuxuv2yo8gjp5srrslm/train.csv?rlkey=6kmofwjvsamlf9dj15m34mjw9&dl=1\n",
      "Resolving www.dropbox.com (www.dropbox.com)... 162.125.11.18\n",
      "Connecting to www.dropbox.com (www.dropbox.com)|162.125.11.18|:443... connected.\n",
      "HTTP request sent, awaiting response... 302 Found\n",
      "Location: https://uc82fce8c47868a649f55a4d5de7.dl.dropboxusercontent.com/cd/0/inline/CPVWu3TnpuUIfZZJZ4fGfwlbBAh0Yq2rn-Z-B86N2nUIVfeX0IIlkkPsiHLXfV0ZAqIL2jTDl0tW4s72KlwbWAgtPq9RP6AWKmsm4hymekDGtzH7_fq6i09hKhZfI67nrgv9_R_7TsI4mAc9XwhuIdQx/file?dl=1# [following]\n",
      "--2024-03-18 12:21:02--  https://uc82fce8c47868a649f55a4d5de7.dl.dropboxusercontent.com/cd/0/inline/CPVWu3TnpuUIfZZJZ4fGfwlbBAh0Yq2rn-Z-B86N2nUIVfeX0IIlkkPsiHLXfV0ZAqIL2jTDl0tW4s72KlwbWAgtPq9RP6AWKmsm4hymekDGtzH7_fq6i09hKhZfI67nrgv9_R_7TsI4mAc9XwhuIdQx/file?dl=1\n",
      "Resolving uc82fce8c47868a649f55a4d5de7.dl.dropboxusercontent.com (uc82fce8c47868a649f55a4d5de7.dl.dropboxusercontent.com)... 162.125.11.15\n",
      "Connecting to uc82fce8c47868a649f55a4d5de7.dl.dropboxusercontent.com (uc82fce8c47868a649f55a4d5de7.dl.dropboxusercontent.com)|162.125.11.15|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 29066993 (28M) [application/binary]\n",
      "Saving to: ‘data/agnews/train.csv’\n",
      "\n",
      "data/agnews/train.c 100%[===================>]  27.72M  18.4MB/s    in 1.5s    \n",
      "\n",
      "2024-03-18 12:21:04 (18.4 MB/s) - ‘data/agnews/train.csv’ saved [29066993/29066993]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "!mkdir -p \"data/agnews/\"\n",
    "!wget \"https://www.dropbox.com/scl/fi/wzcuxuv2yo8gjp5srrslm/train.csv?rlkey=6kmofwjvsamlf9dj15m34mjw9&dl=1\" -O \"data/agnews/train.csv\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eba32ef6-fb27-4a3a-81e7-63f86a720631",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_csv(\"./data/agnews/train.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ab82e0d-8c9c-427c-8523-51cbba53963d",
   "metadata": {},
   "outputs": [],
   "source": [
    "class_to_label = {1: \"World\", 2: \"Sports\", 3: \"Business\", 4: \"Sci/Tech\"}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f23d95d9-1092-4ace-93c7-f1413bacdcd8",
   "metadata": {},
   "outputs": [],
   "source": [
    "df[\"Label\"] = df[\"Class Index\"].map(class_to_label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1028fdb-6d73-495c-8b31-e3ae33f47685",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Class Index</th>\n",
       "      <th>Title</th>\n",
       "      <th>Description</th>\n",
       "      <th>Label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>3</td>\n",
       "      <td>Wall St. Bears Claw Back Into the Black (Reuters)</td>\n",
       "      <td>Reuters - Short-sellers, Wall Street's dwindli...</td>\n",
       "      <td>Business</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>3</td>\n",
       "      <td>Carlyle Looks Toward Commercial Aerospace (Reu...</td>\n",
       "      <td>Reuters - Private investment firm Carlyle Grou...</td>\n",
       "      <td>Business</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>Oil and Economy Cloud Stocks' Outlook (Reuters)</td>\n",
       "      <td>Reuters - Soaring crude prices plus worries\\ab...</td>\n",
       "      <td>Business</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>Iraq Halts Oil Exports from Main Southern Pipe...</td>\n",
       "      <td>Reuters - Authorities have halted oil export\\f...</td>\n",
       "      <td>Business</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>3</td>\n",
       "      <td>Oil prices soar to all-time record, posing new...</td>\n",
       "      <td>AFP - Tearaway world oil prices, toppling reco...</td>\n",
       "      <td>Business</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   Class Index                                              Title  \\\n",
       "0            3  Wall St. Bears Claw Back Into the Black (Reuters)   \n",
       "1            3  Carlyle Looks Toward Commercial Aerospace (Reu...   \n",
       "2            3    Oil and Economy Cloud Stocks' Outlook (Reuters)   \n",
       "3            3  Iraq Halts Oil Exports from Main Southern Pipe...   \n",
       "4            3  Oil prices soar to all-time record, posing new...   \n",
       "\n",
       "                                         Description     Label  \n",
       "0  Reuters - Short-sellers, Wall Street's dwindli...  Business  \n",
       "1  Reuters - Private investment firm Carlyle Grou...  Business  \n",
       "2  Reuters - Soaring crude prices plus worries\\ab...  Business  \n",
       "3  Reuters - Authorities have halted oil export\\f...  Business  \n",
       "4  AFP - Tearaway world oil prices, toppling reco...  Business  "
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b9b4f69-5233-447e-80db-26abd2db9df8",
   "metadata": {},
   "source": [
    "### Create LabelledSimpleDataExample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed53763f-b362-481f-b2a6-ac940675c650",
   "metadata": {},
   "outputs": [],
   "source": [
    "examples = []\n",
    "for index, row in df.iterrows():\n",
    "    example = LabelledSimpleDataExample(\n",
    "        reference_label=row[\"Label\"],\n",
    "        text=f\"{row['Title']} {row['Description']}\",\n",
    "        text_by=CreatedBy(type=CreatedByType.HUMAN),\n",
    "    )\n",
    "    examples.append(example)\n",
    "\n",
    "simple_dataset = LabelledSimpleDataset(examples=examples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "729d3561-b0f9-4dc1-9bbe-9913f5114095",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>reference_label</th>\n",
       "      <th>text</th>\n",
       "      <th>text_by</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Business</td>\n",
       "      <td>Wall St. Bears Claw Back Into the Black (Reute...</td>\n",
       "      <td>human</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Business</td>\n",
       "      <td>Carlyle Looks Toward Commercial Aerospace (Reu...</td>\n",
       "      <td>human</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Business</td>\n",
       "      <td>Oil and Economy Cloud Stocks' Outlook (Reuters...</td>\n",
       "      <td>human</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Business</td>\n",
       "      <td>Iraq Halts Oil Exports from Main Southern Pipe...</td>\n",
       "      <td>human</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Business</td>\n",
       "      <td>Oil prices soar to all-time record, posing new...</td>\n",
       "      <td>human</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  reference_label                                               text text_by\n",
       "0        Business  Wall St. Bears Claw Back Into the Black (Reute...   human\n",
       "1        Business  Carlyle Looks Toward Commercial Aerospace (Reu...   human\n",
       "2        Business  Oil and Economy Cloud Stocks' Outlook (Reuters...   human\n",
       "3        Business  Iraq Halts Oil Exports from Main Southern Pipe...   human\n",
       "4        Business  Oil prices soar to all-time record, posing new...   human"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "simple_dataset.to_pandas()[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58cea2f3-73c8-4300-989c-c4d3cc3a6bfe",
   "metadata": {},
   "outputs": [],
   "source": [
    "simple_dataset.save_json(\"agnews.json\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "diff-privacy",
   "language": "python",
   "name": "diff-privacy"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
